// Copyright 1998-2016 Glenn McIntosh
// licensed under the GNU General Public Licence version 3
#pragma once

/** @file lsq.h
	Least Squares Curve Fitting
	*/

// include files
#include "linear.h"
#include <array>
#include <functional>
#include <stdexcept>
#include <vector>

namespace math
{
/** Scalar floating-point type used by the least squares code. */
using real = double;

/** Dynamically sized one-dimensional array of real values. */
using Vector = std::vector<real>;

/** Dynamically sized two-dimensional array of real values, as a vector of Vectors. */
using Vector2D = std::vector<Vector>;

/** Least Squares Curve Fit. */
template<size_t N> class Lsq
{
	typedef real (*FnI)(real x, int i); /**< Independent variable transform function. */
	typedef std::array<FnI, N> ArrayFnI; /**< Vector of independent variable transform functions. */
	typedef real (*Fn)(real x); /**< Dependent variable transform function. */

	/** Polynomial function.
		Used as the default transform for the independent axis.
		@param x input value
		@param i power index
		@return input value to the power of index
		*/
	static real PolynomialFn(real x, int i)
	{
		real y = 1;
		while (i-- > 0)
			y *= x;
		return y;
	}

	/** Identity function.
		Used as the default transform for the dependent axis.
		@param x input value
		@return input value
		*/
	static real IdentityFn(real x)
	{
		return x;
	}

public:
	/** Create least squares fit object.
		The curve being fitted is of the form fnY(y) = fnX[0](x) + fnX[1](x) + ...
		@param fnX array of functions of the independent variable
		@param fnY function of the independent variable
		@param iFnY inverse function of the independent variable
		*/
	Lsq(ArrayFnI fnX = std::array<FnI, N>{&PolynomialFn, &PolynomialFn}, Fn fnY = &IdentityFn, Fn iFnY = &IdentityFn)
	: mN(fnX.size()), mFnX(fnX), mFnY(fnY), mIFnY(iFnY)
	{
	}

	/** Add point.
		@param x new independent axis value
		@param y new dependent axis value
		@param weight relative importance of point
		*/
	void add(real x, real y, real weight = 1.)
	{
		// fill in sum array
		real yy = mFnY(y);
		Array<N> xx{};
		for (int i = 0; i < mN; ++i)
		{
			xx[i] = mFnX[i](x, i);
			for (int j = 0; j <= i; ++j)
				mXSums[i][j] += xx[i]*xx[j]*weight;
			mYSums[i] += xx[i]*yy*weight;
		}
	}

	/** Add array of points.
		@param x array of independent axis values
		@param y array of dependent axis values
		*/
	void add(Vector x, Vector y)
	{
		int n = x.size();

		// check parameters
		if (x.size() != y.size())
			throw std::invalid_argument("arrays of different sizes");

		// for each element
		for (int i = 0; i < n; ++i)
			add(x[i], y[i]);
	}

	/** Add weighted array of points.
		@param x array of independent axis values
		@param y array of dependent axis values
		@param weight array of relative importance of points
		*/
	void add(Vector x, Vector y, Vector weight)
	{
		int n = x.size();

		// check parameters
		if (x.size() != y.size() || x.size() != weight.size())
			throw std::invalid_argument("arrays of different sizes");

		// for each element
		for (int i = 0; i < n; ++i)
			add(x[i], y[i], weight[i]);
	}

	/** Calculate best fit.
		This calculates a least squares linear fit between the independent and
		dependent variables. The sum of the squares of the differences between the
		predicted value of fnY(y) and the actual value at each point is minimized.
		For non-linear transform functions fnY, this is not the same as a least
		squares fit directly to y (but may be more appropriate).
		*/
	void fit()
	{
		// mirror array across diagonal
		for (int i = 0; i < mN; ++i)
			for (int j = i+1; j < mN; ++j)
				mXSums[i][j] = mXSums[j][i];

		// solve
		real d;
		ArrayIndex<N> index{};
		luDecompose(mXSums, index, d);
		luBackSubstitute(mXSums, index, mYSums);
	}

	/** Evaluate point.
		@param x independent axis value
		@return dependent axis value y
		*/
	real operator()(real x)
	{
		real xSum = 0.;
		for (int i = 0; i < mN; ++i)
			xSum += mYSums[i]*mFnX[i](x, i);
		return mIFnY(xSum);
	}

	/** Evaluate array of points.
		@param x array of independent axis values to be evaluated in place
		*/
	void evaluate(Vector &x)
	{
		for (int i = 0; i < static_cast<int>(x.size()); ++i)
			x[i] = operator()(x[i]);
	}

private:
	int mN;
	Matrix<N,N> mXSums{};
	Array<N> mYSums{};
	ArrayFnI mFnX;
	Fn mFnY, mIFnY;
};
}
